import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    """Plastic return mapping on the principal logarithmic strains of a deformation gradient.

    The deviatoric part of the log-strain is projected back onto the sphere of radius
    ``yield_stress / (2 * mu)`` whenever it lies outside it; the volumetric (trace) part
    of the log-strain is left unchanged.
    """

    def __init__(
        self,
        youngs_modulus_log: float = 6.0,
        poissons_ratio_unconstrained: float = -1.0,
        yield_stress: float = 2.5,
    ):
        """
        Create the model with three trainable scalar parameters.

        Args:
            youngs_modulus_log (float): log of Young's modulus; the forward pass uses
                E = exp(youngs_modulus_log).
            poissons_ratio_unconstrained (float): unconstrained scalar mapped to
                Poisson's ratio nu = 0.49 * sigmoid(value), i.e. nu in (0, 0.49).
            yield_stress (float): yield threshold; used as-is in the forward pass
                (it is not constrained to be positive).
        """
        super().__init__()
        # Registration order is kept so state_dict keys stay in the same order.
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))
        self.poissons_ratio_unconstrained = nn.Parameter(torch.tensor(poissons_ratio_unconstrained))
        self.yield_stress = nn.Parameter(torch.tensor(yield_stress))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Return-map a batch of deformation gradients.

        Args:
            F (torch.Tensor): deformation gradients, shape (B, 3, 3).

        Returns:
            torch.Tensor: corrected deformation gradients U diag(s') Vh, shape (B, 3, 3),
            where s' are the corrected principal stretches.
        """
        # Shear modulus from the unconstrained trainable scalars.
        nu = torch.sigmoid(self.poissons_ratio_unconstrained) * 0.49  # in (0, 0.49)
        shear = self.youngs_modulus_log.exp() / (2.0 * (1.0 + nu))
        # Largest deviatoric log-strain norm that is left untouched.
        radius = self.yield_stress / (2.0 * shear)

        # Principal stretches: F = U diag(s) Vh.
        U, s, Vh = torch.linalg.svd(F, full_matrices=False)

        # Logarithmic principal strain; the floor keeps the log finite for collapsed stretches.
        log_s = s.clamp_min(1e-4).log()  # (B, 3)

        # Deviatoric part (trace removed) and its norm; the tiny offset avoids 0/0 below.
        dev = log_s - log_s.sum(dim=1, keepdim=True) / 3.0
        dev_norm = dev.norm(dim=1, keepdim=True) + 1e-12  # (B, 1)

        # Fraction of the deviatoric strain to remove; zero when inside the yield surface.
        scale = ((dev_norm - radius) / dev_norm).clamp_min(0.0)  # (B, 1)

        # Corrected stretches, then recompose with the original singular vectors.
        stretched = (log_s - scale * dev).exp()  # (B, 3)
        return U @ (torch.diag_embed(stretched) @ Vh)


class ElasticityModel(nn.Module):
    """Corotated elasticity returning the Kirchhoff stress for a batch of 3x3 deformation gradients."""

    def __init__(
        self,
        youngs_modulus_log: float = 11.7,
        poissons_ratio_unconstrained: float = 5.5,
    ):
        """
        Corotated Elasticity model with trainable physical parameters.

        Args:
            youngs_modulus_log (float): log Young's modulus; the forward pass uses
                E = exp(youngs_modulus_log), which keeps E positive.
            poissons_ratio_unconstrained (float): unconstrained scalar for Poisson's ratio;
                the forward pass uses nu = 0.49 * sigmoid(poissons_ratio_unconstrained),
                so nu lies in (0, 0.49).
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))  # scalar
        self.poissons_ratio_unconstrained = nn.Parameter(torch.tensor(poissons_ratio_unconstrained))  # scalar

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute Kirchhoff stress tensor from deformation gradient tensor.

        With R = U Vh from the SVD of F and J the product of the (clamped) singular
        values, the corotated energy psi = mu * ||F - R||^2 + la / 2 * (J - 1)^2 has
        first Piola-Kirchhoff stress P = 2 mu (F - R) + la J (J - 1) F^{-T}, hence the
        Kirchhoff stress is tau = P F^T = 2 mu (F - R) F^T + la J (J - 1) I.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3).

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tensor (B, 3, 3).
        """
        youngs_modulus = self.youngs_modulus_log.exp()  # scalar
        poissons_ratio = torch.sigmoid(self.poissons_ratio_unconstrained) * 0.49  # scalar in (0, 0.49)

        # Lamé parameters from (E, nu).
        mu = youngs_modulus / (2.0 * (1.0 + poissons_ratio))  # shear modulus μ
        la = youngs_modulus * poissons_ratio / ((1.0 + poissons_ratio) * (1.0 - 2.0 * poissons_ratio))  # lambda λ

        # SVD: F = U Σ V^T
        # NOTE(review): per the torch.linalg.svd docs, gradients through U / Vh are not
        # finite when singular values repeat (e.g. at F = I); confirm this is acceptable.
        U, sigma, Vh = torch.linalg.svd(F, full_matrices=False)  # (B,3,3), (B,3), (B,3,3)

        # Clamp singular values for numerical stability
        sigma_clamped = torch.clamp_min(sigma, 1e-5)  # (B,3)

        # Orthogonal polar factor R = U V^T (a reflection rather than a rotation if det F < 0,
        # since torch.linalg.svd returns non-negative singular values).
        R = torch.matmul(U, Vh)  # (B,3,3)

        Ft = F.transpose(1, 2)  # (B,3,3)

        # Corotated part of tau: 2 * mu * (F - R) * F^T  (already post-multiplied by F^T)
        corotated_stress = 2.0 * mu * torch.matmul(F - R, Ft)  # (B,3,3)

        # J as the product of clamped singular values. Because singular values are
        # non-negative, J >= 1e-15 here and does not go negative for inverted elements.
        J = torch.prod(sigma_clamped, dim=1).view(-1, 1, 1)  # (B,1,1)

        # Identity tensor I
        I = torch.eye(3, dtype=F.dtype, device=F.device).unsqueeze(0)  # (1,3,3)

        # Volumetric part of tau: λ * J * (J - 1) * I  (= λ J (J - 1) F^{-T} times F^T)
        volume_stress = la * J * (J - 1.0) * I  # (B,3,3)

        # Both terms are already post-multiplied by F^T, so their sum IS the Kirchhoff
        # stress; multiplying by F^T once more would apply that factor twice.
        kirchhoff_stress = corotated_stress + volume_stress  # (B,3,3)

        return kirchhoff_stress
